import numpy as np
from common.functions import softmax, cross_entropy_error

class MatMul:
    def __init__(self, W):
        self.params = [W]
        self.grads = [np.zeros_like(W)]
        self.x = None

    def forward(self, x):
        W, = self.params
        out = np.dot(x, W)
        self.x = x
        return out

    def backward(self, dout):
        W, = self.params
        dx = np.dot(dout, W.T)
        dW = np.dot(self.x.T, dout)
        self.grads[0][...] = dW  # write in place so external references to grads stay valid
        return dx
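
# A minimal usage sketch for MatMul (the shapes below are assumptions chosen
# for illustration, not part of the layer's contract):
def _demo_matmul():
    rng = np.random.default_rng(0)
    layer = MatMul(rng.standard_normal((4, 3)))   # W: (4, 3)
    x = rng.standard_normal((2, 4))               # batch of 2 inputs
    out = layer.forward(x)                        # out: (2, 3)
    dx = layer.backward(np.ones_like(out))        # dx: (2, 4)
    assert dx.shape == x.shape
    assert layer.grads[0].shape == layer.params[0].shape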
    
class Affine:
    def __init__(self, W, b):
        self.params = [W, b]
        self.grads = [np.zeros_like(W), np.zeros_like(b)]
        self.x = None

    def forward(self, x):
        W, b = self.params
        self.x = x
        out = np.dot(x, W) + b
        return out

    def backward(self, dout):
        W, b = self.params
        dx = np.dot(dout, W.T)
        dW = np.dot(self.x.T, dout)
        db = np.sum(dout, axis=0)  # bias gradient: sum over the batch axis

        self.grads[0][...] = dW
        self.grads[1][...] = db
        return dx
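
# A quick sketch of Affine's bias gradient: because b is broadcast over the
# batch, db must be the column-wise sum of dout. The shapes are illustrative.
def _demo_affine():
    rng = np.random.default_rng(0)
    layer = Affine(rng.standard_normal((4, 3)), np.zeros(3))
    dout = rng.standard_normal((5, 3))
    layer.forward(rng.standard_normal((5, 4)))
    layer.backward(dout)
    assert np.allclose(layer.grads[1], dout.sum(axis=0))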
    
class Sigmoid:
    def __init__(self):
        self.params, self.grads = [], []
        self.out = None

    def forward(self, x):
        out = 1 / (1 + np.exp(-x))
        self.out = out
        return out

    def backward(self, dout):
        dx = dout * (1.0 - self.out) * self.out  # dy/dx = y * (1 - y)
        return dx
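
# A sketch checking Sigmoid's analytic gradient y * (1 - y) against a central
# finite difference (epsilon and tolerance are assumptions for this example):
def _demo_sigmoid():
    layer = Sigmoid()
    x = np.array([[0.5, -1.0, 2.0]])
    layer.forward(x)
    dx = layer.backward(np.ones_like(x))

    def sig(v):
        return 1 / (1 + np.exp(-v))

    eps = 1e-5
    numeric = (sig(x + eps) - sig(x - eps)) / (2 * eps)
    assert np.allclose(dx, numeric)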
    
class SigmoidWithLoss:
    def __init__(self):
        self.params, self.grads = [], []
        self.loss = None
        self.y = None  # output of sigmoid
        self.t = None  # teacher labels

    def forward(self, x, t):
        self.t = t
        self.y = 1 / (1 + np.exp(-x))

        # Treat the binary problem as two classes, (1 - y, y), so the shared
        # cross_entropy_error function can be reused.
        self.loss = cross_entropy_error(np.c_[1 - self.y, self.y], self.t)

        return self.loss

    def backward(self, dout=1):
        batch_size = self.t.shape[0]
        dx = (self.y - self.t) * dout / batch_size
        return dx
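
# A usage sketch for SigmoidWithLoss: x holds one raw score per sample and t
# holds 0/1 labels, matching how this layer is used for binary classification.
# The concrete numbers are illustrative only.
def _demo_sigmoid_with_loss():
    layer = SigmoidWithLoss()
    x = np.array([2.0, -1.0, 0.5])   # raw scores
    t = np.array([1, 0, 1])          # binary teacher labels
    layer.forward(x, t)              # scalar loss
    dx = layer.backward()            # (y - t) / batch_size
    assert dx.shape == x.shape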
    
class Softmax:
    def __init__(self):
        self.params, self.grads = [], []
        self.out = None

    def forward(self, x):
        self.out = softmax(x)
        return self.out

    def backward(self, dout):
        dx = self.out * dout
        sumdx = np.sum(dx, axis=1, keepdims=True)
        dx -= self.out * sumdx
        return dx
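
# Softmax.backward computes the Jacobian-vector product y * (dout - sum(y * dout)).
# One easy sanity check: a constant dout yields a zero gradient, since the
# softmax outputs always sum to 1.
def _demo_softmax():
    layer = Softmax()
    layer.forward(np.array([[1.0, 2.0, 3.0]]))
    dx = layer.backward(np.ones((1, 3)))
    assert np.allclose(dx, 0.0)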
    
    
class SoftmaxWithLoss:
    def __init__(self):
        self.params, self.grads = [], []
        self.y = None  # output of softmax
        self.t = None  # teacher labels

    def forward(self, x, t):
        self.t = t
        self.y = softmax(x)

        # If the teacher labels are one-hot vectors, convert them to
        # class-index labels.
        if self.t.size == self.y.size:
            self.t = self.t.argmax(axis=1)

        loss = cross_entropy_error(self.y, self.t)
        return loss

    def backward(self, dout=1):
        batch_size = self.t.shape[0]

        dx = self.y.copy()
        dx[np.arange(batch_size), self.t] -= 1  # softmax output minus one-hot target
        dx *= dout
        dx = dx / batch_size

        return dx
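
# A usage sketch for SoftmaxWithLoss: one-hot teacher labels are accepted and
# converted to index labels internally, and backward returns
# (y - one_hot(t)) / batch_size. The inputs here are illustrative.
def _demo_softmax_with_loss():
    layer = SoftmaxWithLoss()
    x = np.array([[0.3, 2.9, 4.0], [0.1, 0.05, 0.6]])
    t = np.array([[0, 0, 1], [1, 0, 0]])   # one-hot labels
    layer.forward(x, t)
    dx = layer.backward()
    assert dx.shape == x.shape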
    
class Embedding:
    def __init__(self, W):
        self.params = [W]
        self.grads = [np.zeros_like(W)]
        self.idx = None

    def forward(self, idx):
        W, = self.params
        self.idx = idx
        out = W[idx]  # row extraction instead of a one-hot matrix product
        return out

    def backward(self, dout):
        dW, = self.grads
        dW[...] = 0

        # Accumulate (rather than assign) so repeated indices add up correctly.
        np.add.at(dW, self.idx, dout)

        return None
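
# A sketch of why Embedding.backward uses np.add.at: when the same row index
# appears more than once in idx, its gradient contributions must accumulate
# instead of overwriting each other. The sizes below are illustrative.
def _demo_embedding():
    layer = Embedding(np.zeros((5, 3)))
    layer.forward(np.array([1, 1, 4]))    # index 1 appears twice
    layer.backward(np.ones((3, 3)))
    dW, = layer.grads
    assert np.allclose(dW[1], 2.0)        # the two contributions were summed
    assert np.allclose(dW[4], 1.0)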
   

